import logging
from typing import Optional

from impacket.dcerpc.v5.dcom import wmi
from impacket.dcerpc.v5.dcomrt import DCOMConnection
from impacket.dcerpc.v5.dtypes import NULL
from pydantic import SecretStr

from common.credentials import Credentials, LMHash, NTHash, Password
from infection_monkey.i_puppet import TargetHost

logger = logging.getLogger(__name__)


def secret_of_type(credentials, type) -> Optional[SecretStr]:
    """
    Extract the secret value from credentials if it is of the requested type

    :param credentials: Credentials whose `secret` attribute is inspected
    :param type: The secret class being requested (Password, LMHash or NTHash)
    :return: The matching secret value, or None if `type` is not one of the
             supported classes or the credentials hold a different kind of secret
    """
    # Each supported secret class paired with the attribute that holds its value
    value_attribute_by_class = (
        (Password, "password"),
        (LMHash, "lm_hash"),
        (NTHash, "nt_hash"),
    )

    for secret_class, value_attribute in value_attribute_by_class:
        if type is secret_class and isinstance(credentials.secret, secret_class):
            return getattr(credentials.secret, value_attribute)

    return None


def get_plaintext(secret: Optional[SecretStr]) -> str:
    """
    Unwrap a secret into its plaintext string

    :param secret: The secret to unwrap, or None
    :return: The value returned by `secret.get_secret_value()`, or an empty
             string if `secret` is None
    """
    return "" if secret is None else secret.get_secret_value()


class WMIClient:
    """
    Client that logs in to a remote host over DCOM (via impacket) and uses the
    resulting WMI services object to create processes on that host
    """

    def __init__(self):
        # WMI services handle obtained by login(); None until login succeeds
        self._wbem_services: Optional[wmi.IWbemServices] = None
        self._connected = False
        # Underlying DCOM connection created by login()
        self._dcom: Optional[DCOMConnection] = None

    def connected(self) -> bool:
        """
        Report whether a login has completed successfully

        :return: True if login() has succeeded, False otherwise
        """
        return self._connected

    def login(self, host: TargetHost, credentials: Credentials):
        """
        Login to the remote host

        :param host: Remote host
        :param credentials: Credentials to use
        :raises Exception: If login fails
        """
        # Impacket has a hard-coded timeout of 120 seconds
        try:
            self._dcom = DCOMConnection(
                str(host.ip),
                username=credentials.identity.username,  # type: ignore[union-attr]
                password=get_plaintext(secret_of_type(credentials, Password)),
                domain=str(host.ip),
                lmhash=get_plaintext(secret_of_type(credentials, LMHash)),
                nthash=get_plaintext(secret_of_type(credentials, NTHash)),
                oxidResolver=True,
            )
            iInterface = self._dcom.CoCreateInstanceEx(
                wmi.CLSID_WbemLevel1Login, wmi.IID_IWbemLevel1Login
            )
        except Exception:
            # The DCOMConnection constructor itself may have raised, in which
            # case there is nothing to disconnect; the helper handles that so
            # the original error is the one that propagates.
            self._disconnect_after_failed_login()

            raise

        iWbemLevel1Login = wmi.IWbemLevel1Login(iInterface)

        try:
            self._wbem_services = iWbemLevel1Login.NTLMLogin("//./root/cimv2", NULL, NULL)
        except Exception:
            self._disconnect_after_failed_login()

            raise
        finally:
            # The login interface is only needed to obtain the services object
            iWbemLevel1Login.RemRelease()

        self._connected = True

    def _disconnect_after_failed_login(self):
        """
        Disconnect the DCOM connection after a login step has failed

        Does nothing if no connection object exists. A KeyError raised while
        disconnecting is logged rather than propagated so that it does not
        replace the login error the caller is about to re-raise.
        """
        if self._dcom is None:
            return

        try:
            self._dcom.disconnect()
        except KeyError:
            logger.exception("Disconnecting the DCOMConnection failed")

    def execute_remote_process(self, command, path) -> bool:
        """
        Execute a process on the remote host

        :param command: Command to execute
        :param path: Path to working directory
        :return: True if the process was executed successfully, False otherwise
        :raises RuntimeError: If no WMI services object is available, i.e. login()
                              has not succeeded
        """
        if self._wbem_services is None:
            raise RuntimeError("Not connected")
        win32_process, _ = self._wbem_services.GetObject("Win32_Process")

        # TODO: Is the path the working directory for the process? If so, we may
        # not need to use the dropper at all.
        result = win32_process.Create(command, path, None)
        # Success requires both a non-zero process ID and a zero return value
        return (result.ProcessId != 0) and (result.ReturnValue == 0)

    def __del__(self):
        """
        Release the WMI services object, disconnect the DCOM connection and
        reset DCOMConnection's class-level state

        Each step runs even if an earlier one raises, so a failure while
        releasing one resource does not prevent the remaining cleanup.
        """
        try:
            if self._wbem_services:
                self._wbem_services.RemRelease()
        finally:
            self._wbem_services = None

            try:
                if self._dcom:
                    self._dcom.disconnect()
            finally:
                self._dcom = None
                WMIClient._dcom_cleanup()

    @staticmethod
    def _dcom_cleanup():
        """
        Reset the class-level PORTMAPS, OID_SET and PINGTIMER attributes of
        DCOMConnection

        The existing mappings are emptied in place before being replaced with
        fresh dicts, and a running ping timer is cancelled and joined.
        """
        # Empty the existing mappings in place (snapshot the keys first, since
        # the mapping is modified while iterating)
        for port_map in list(DCOMConnection.PORTMAPS.keys()):
            del DCOMConnection.PORTMAPS[port_map]
        for oid_set in list(DCOMConnection.OID_SET.keys()):
            del DCOMConnection.OID_SET[oid_set]

        DCOMConnection.OID_SET = {}
        DCOMConnection.PORTMAPS = {}
        if DCOMConnection.PINGTIMER:
            DCOMConnection.PINGTIMER.cancel()
            DCOMConnection.PINGTIMER.join()
            DCOMConnection.PINGTIMER = None
